one-api/relay/midjourney/relay-mj.go
2024-04-06 15:07:33 +08:00

579 lines
19 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// Author: Calcium-Ion
// GitHub: https://github.com/Calcium-Ion/new-api
// Path: relay/relay-mj.go
package midjourney
import (
"bytes"
"encoding/json"
"io"
"log"
"net/http"
"one-api/common"
"one-api/controller"
"one-api/model"
provider "one-api/providers/midjourney"
"one-api/relay"
"one-api/relay/util"
"one-api/types"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
)
// RelayMidjourneyImage streams the image of a Midjourney task to the client.
//
// The task is looked up by its Midjourney id (route param "id") and its stored
// ImageUrl is fetched and copied to the response, using the upstream
// Content-Type (image/jpeg when upstream sends none).
// - A missing task yields 400.
// - A failed fetch yields 500.
// - A non-200 upstream status is passed through, with the upstream body as
//   the "error" message.
//
// NOTE(review): ImageUrl is overwritten from callback payloads in
// RelayMidjourneyNotify; if that endpoint is reachable by untrusted callers,
// this fetch is a server-side request forgery vector — confirm.
func RelayMidjourneyImage(c *gin.Context) {
	taskId := c.Param("id")
	midjourneyTask := model.GetByOnlyMJId(taskId)
	if midjourneyTask == nil {
		c.JSON(400, gin.H{
			"error": "midjourney_task_not_found",
		})
		return
	}
	// http.Get would use http.DefaultClient, which has no timeout, so a stalled
	// image host could hold this handler open indefinitely. Tie the fetch to
	// the client's request context (so it is abandoned if the client goes
	// away) and cap its total duration, including reading the body.
	req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, midjourneyTask.ImageUrl, nil)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": "http_get_image_failed",
		})
		return
	}
	client := &http.Client{Timeout: 2 * time.Minute}
	resp, err := client.Do(req)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{
			"error": "http_get_image_failed",
		})
		return
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		responseBody, _ := io.ReadAll(resp.Body)
		c.JSON(resp.StatusCode, gin.H{
			"error": string(responseBody),
		})
		return
	}
	// Take the MIME type from the upstream Content-Type header, defaulting to
	// JPEG when the upstream does not provide one.
	contentType := resp.Header.Get("Content-Type")
	if contentType == "" {
		contentType = "image/jpeg"
	}
	c.Writer.Header().Set("Content-Type", contentType)
	// Stream the image body straight to the client; by now the status and
	// headers are committed, so a copy failure can only be logged.
	if _, err = io.Copy(c.Writer, resp.Body); err != nil {
		log.Println("Failed to stream image:", err)
	}
}
// RelayMidjourneyNotify applies a Midjourney callback payload to the matching
// stored task.
//
// The payload is decoded from the (reusable) request body and the task is
// looked up by the payload's MjId. The task's progress, prompt, state, timing,
// image, status and failure fields are overwritten with the callback values,
// and the task is persisted.
//
// It returns nil on success, or a code-4 MidjourneyResponse naming the step
// that failed.
func RelayMidjourneyNotify(c *gin.Context) *provider.MidjourneyResponse {
	failure := func(description string) *provider.MidjourneyResponse {
		return &provider.MidjourneyResponse{Code: 4, Description: description}
	}

	var payload provider.MidjourneyDto
	if err := common.UnmarshalBodyReusable(c, &payload); err != nil {
		return failure("bind_request_body_failed")
	}

	task := model.GetByOnlyMJId(payload.MjId)
	if task == nil {
		return failure("midjourney_task_not_found")
	}

	// Mirror the callback's view of the task onto the stored record.
	task.State = payload.State
	task.Status = payload.Status
	task.Progress = payload.Progress
	task.PromptEn = payload.PromptEn
	task.ImageUrl = payload.ImageUrl
	task.FailReason = payload.FailReason
	task.SubmitTime = payload.SubmitTime
	task.StartTime = payload.StartTime
	task.FinishTime = payload.FinishTime

	if err := task.Update(); err != nil {
		return failure("update_midjourney_task_failed")
	}
	return nil
}
// coverMidjourneyTaskDto converts a stored task record into the DTO returned
// by the task-query endpoints.
//
// The stored upstream image URL is not exposed. When one exists, ImageUrl is
// set to common.ServerAddress + "/mj/image/" + MjId. For tasks whose status is
// not "SUCCESS", a "rand" query parameter (current Unix nanoseconds) is
// appended, presumably to defeat client-side caching of in-progress images.
//
// Buttons and Properties are decoded from their JSON string columns; values
// that fail to decode are silently left unset.
func coverMidjourneyTaskDto(originTask *model.Midjourney) (midjourneyTask provider.MidjourneyDto) {
	var dto provider.MidjourneyDto

	// Identity and request description.
	dto.MjId = originTask.MjId
	dto.Action = originTask.Action
	dto.Prompt = originTask.Prompt
	dto.PromptEn = originTask.PromptEn
	dto.Description = originTask.Description

	// Lifecycle state.
	dto.State = originTask.State
	dto.Status = originTask.Status
	dto.Progress = originTask.Progress
	dto.FailReason = originTask.FailReason
	dto.SubmitTime = originTask.SubmitTime
	dto.StartTime = originTask.StartTime
	dto.FinishTime = originTask.FinishTime

	// Point clients at this server's image route rather than the upstream URL.
	dto.ImageUrl = ""
	if originTask.ImageUrl != "" {
		proxyURL := common.ServerAddress + "/mj/image/" + originTask.MjId
		if originTask.Status != "SUCCESS" {
			proxyURL += "?rand=" + strconv.FormatInt(time.Now().UnixNano(), 10)
		}
		dto.ImageUrl = proxyURL
	}

	// Best-effort decoding of the JSON-encoded columns.
	if raw := originTask.Buttons; raw != "" {
		var buttons []provider.ActionButton
		if json.Unmarshal([]byte(raw), &buttons) == nil {
			dto.Buttons = buttons
		}
	}
	if raw := originTask.Properties; raw != "" {
		var properties provider.Properties
		if json.Unmarshal([]byte(raw), &properties) == nil {
			dto.Properties = &properties
		}
	}

	return dto
}
// RelaySwapFace forwards an InsightFace swap-face request to a Midjourney
// channel, settles the pre-reserved quota, records the task locally and writes
// the upstream response back to the client as JSON.
//
// Quota is consumed only when the upstream answers HTTP 200 with code 1;
// otherwise the reservation is undone. A non-nil return value is an error
// response for the caller to render; nil means the response was written.
func RelaySwapFace(c *gin.Context) *provider.MidjourneyResponse {
	mjProvider, errWithMJ := getMJProviderWithRequest(c, provider.RelayModeMidjourneySwapFace, nil)
	if errWithMJ != nil {
		return errWithMJ
	}
	startTime := time.Now().UnixNano() / int64(time.Millisecond)
	userId := c.GetInt("id")
	var swapFaceRequest provider.SwapFaceRequest
	err := common.UnmarshalBodyReusable(c, &swapFaceRequest)
	if err != nil {
		return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
	}
	if swapFaceRequest.SourceBase64 == "" || swapFaceRequest.TargetBase64 == "" {
		return provider.MidjourneyErrorWrapper(provider.MjRequestError, "sour_base64_and_target_base64_is_required")
	}
	quotaInstance, errWithOA := getQuota(c, provider.MjActionSwapFace)
	if errWithOA != nil {
		return &provider.MidjourneyResponse{
			Code:        4,
			Description: errWithOA.Message,
		}
	}
	requestURL := getMjRequestPath(c.Request.URL.String())
	mjResp, _, err := mjProvider.Send(60, requestURL)
	if err != nil {
		quotaInstance.Undo(c)
		return &mjResp.Response
	}
	if mjResp.StatusCode == 200 && mjResp.Response.Code == 1 {
		quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1000, TotalTokens: 1000})
	} else {
		quotaInstance.Undo(c)
	}
	quota := int(quotaInstance.GetInputRatio() * 1000)
	midjResponse := &mjResp.Response
	midjourneyTask := &model.Midjourney{
		UserId:      userId,
		Code:        midjResponse.Code,
		Action:      provider.MjActionSwapFace,
		MjId:        midjResponse.Result,
		Prompt:      "InsightFace",
		PromptEn:    "",
		Description: midjResponse.Description,
		State:       "",
		SubmitTime:  startTime,
		StartTime:   time.Now().UnixNano() / int64(time.Millisecond),
		FinishTime:  0,
		ImageUrl:    "",
		Status:      "",
		Progress:    "0%",
		FailReason:  "",
		ChannelId:   c.GetInt("channel_id"),
		Quota:       quota,
	}
	err = midjourneyTask.Insert()
	if err != nil {
		return provider.MidjourneyErrorWrapper(provider.MjRequestError, "insert_midjourney_task_failed")
	}
	// Trigger the bulk task-status refresh.
	controller.ActivateUpdateMidjourneyTaskBulk()

	// Encode first, so nothing about the response is touched if encoding fails.
	respBody, err := json.Marshal(midjResponse)
	if err != nil {
		return provider.MidjourneyErrorWrapper(provider.MjRequestError, "unmarshal_response_body_failed")
	}
	// Declare the body as JSON (as RelayMidjourneyTask does); without this,
	// net/http sniffs the payload and labels it text/plain.
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(mjResp.StatusCode)
	if _, err = c.Writer.Write(respBody); err != nil {
		return provider.MidjourneyErrorWrapper(provider.MjRequestError, "copy_response_body_failed")
	}
	return nil
}
// RelayMidjourneyTaskImageSeed fetches the image seed of one of the current
// user's tasks (route param "id") from the channel that produced the task,
// and writes the upstream response back as JSON with the upstream status.
//
// A non-nil return value is an error response for the caller to render; nil
// means the response was written.
func RelayMidjourneyTaskImageSeed(c *gin.Context) *provider.MidjourneyResponse {
	taskId := c.Param("id")
	userId := c.GetInt("id")
	originTask := model.GetByMJId(userId, taskId)
	if originTask == nil {
		return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_no_found")
	}
	// Query the channel that produced the original task.
	mjProvider, errWithMJ := getMJProviderWithChannelId(c, originTask.ChannelId)
	if errWithMJ != nil {
		return errWithMJ
	}
	requestURL := getMjRequestPath(c.Request.URL.String())
	midjResponseWithStatus, _, err := mjProvider.Send(30, requestURL)
	if err != nil {
		return &midjResponseWithStatus.Response
	}
	midjResponse := &midjResponseWithStatus.Response

	// Encode first, so nothing about the response is touched if encoding fails.
	respBody, err := json.Marshal(midjResponse)
	if err != nil {
		return provider.MidjourneyErrorWrapper(provider.MjRequestError, "unmarshal_response_body_failed")
	}
	// Declare the body as JSON (as RelayMidjourneyTask does); without this,
	// net/http sniffs the payload and labels it text/plain.
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
	if _, err = c.Writer.Write(respBody); err != nil {
		return provider.MidjourneyErrorWrapper(provider.MjRequestError, "copy_response_body_failed")
	}
	return nil
}
// RelayMidjourneyTask serves the task-query endpoints for the current user.
//
// Behaviour by relayMode:
// - RelayModeMidjourneyTaskFetch returns the single task named by route param
//   "id".
// - RelayModeMidjourneyTaskFetchByCondition returns the tasks whose ids are
//   listed in the JSON body field "ids". The result is always a JSON array,
//   possibly empty.
// - Any other mode writes an empty body.
//
// Errors are returned as code-4 responses for the caller to render; nil means
// the response was written.
func RelayMidjourneyTask(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
	failure := func(description string) *provider.MidjourneyResponse {
		return &provider.MidjourneyResponse{Code: 4, Description: description}
	}
	userId := c.GetInt("id")

	// payload stays nil for unrecognised modes, in which case nothing is encoded.
	var payload interface{}
	switch relayMode {
	case provider.RelayModeMidjourneyTaskFetch:
		originTask := model.GetByMJId(userId, c.Param("id"))
		if originTask == nil {
			return failure("task_no_found")
		}
		payload = coverMidjourneyTaskDto(originTask)
	case provider.RelayModeMidjourneyTaskFetchByCondition:
		var condition struct {
			IDs []string `json:"ids"`
		}
		if err := c.BindJSON(&condition); err != nil {
			return failure("do_request_failed")
		}
		// Start from an empty, non-nil slice so the response is "[]", not "null".
		tasks := make([]provider.MidjourneyDto, 0)
		if len(condition.IDs) != 0 {
			for _, originTask := range model.GetByMJIds(userId, condition.IDs) {
				tasks = append(tasks, coverMidjourneyTaskDto(originTask))
			}
		}
		payload = tasks
	}

	var respBody []byte
	if payload != nil {
		encoded, err := json.Marshal(payload)
		if err != nil {
			return failure("unmarshal_response_body_failed")
		}
		respBody = encoded
	}

	c.Writer.Header().Set("Content-Type", "application/json")
	if _, err := io.Copy(c.Writer, bytes.NewBuffer(respBody)); err != nil {
		return failure("copy_response_body_failed")
	}
	return nil
}
// RelayMidjourneySubmit handles the Midjourney task-submission endpoints.
//
// Steps:
// 1. Decode the request and pick a provider.
// 2. For follow-up operations on an existing task, verify that the original
//    task belongs to the user and re-target the provider at the channel that
//    produced it. Follow-ups are change, simple-change and modal requests,
//    plus Midjourney Plus "action" requests, which are first rewritten into
//    change requests.
// 3. Forward the request upstream.
// 4. Consume or undo the reserved quota based on the upstream result code.
// 5. Record the new task.
// 6. Write the upstream body back, with result codes 21/22 rewritten to 1.
//
// A non-nil return value is an error response for the caller to render; nil
// means the response has already been written.
func RelayMidjourneySubmit(c *gin.Context, relayMode int) *provider.MidjourneyResponse {
	userId := c.GetInt("id")
	consumeQuota := true
	var midjRequest provider.MidjourneyRequest
	err := common.UnmarshalBodyReusable(c, &midjRequest)
	if err != nil {
		return provider.MidjourneyErrorWrapper(provider.MjRequestError, "bind_request_body_failed")
	}
	mjProvider, errWithMJ := getMJProviderWithRequest(c, relayMode, &midjRequest)
	if errWithMJ != nil {
		return errWithMJ
	}
	if relayMode == provider.RelayModeMidjourneyAction {
		// Midjourney Plus encodes the target task in customId; convert it into
		// a regular change request.
		mjErr := CoverPlusActionToNormalAction(&midjRequest)
		if mjErr != nil {
			return mjErr
		}
		relayMode = provider.RelayModeMidjourneyChange
	}

	switch {
	case relayMode == provider.RelayModeMidjourneyImagine: // new drawing; may be repeated
		if midjRequest.Prompt == "" {
			return provider.MidjourneyErrorWrapper(provider.MjRequestError, "prompt_is_required")
		}
		midjRequest.Action = provider.MjActionImagine
	case relayMode == provider.RelayModeMidjourneyDescribe: // image-to-text; may be repeated
		midjRequest.Action = provider.MjActionDescribe
	case relayMode == provider.RelayModeMidjourneyShorten: // prompt shortening (plus only); may be repeated
		midjRequest.Action = provider.MjActionShorten
	case relayMode == provider.RelayModeMidjourneyBlend: // blend; may be repeated
		midjRequest.Action = provider.MjActionBlend
	case midjRequest.TaskId != "" ||
		relayMode == provider.RelayModeMidjourneyChange ||
		relayMode == provider.RelayModeMidjourneySimpleChange:
		// Follow-up on an existing task (upscale, variation, redraw, ...).
		// Change and simple-change must enter this branch even without a
		// top-level TaskId:
		// - change, so a missing id is rejected (task_id_is_required);
		// - simple-change, because its task id is carried inside Content.
		mjId := ""
		switch relayMode {
		case provider.RelayModeMidjourneyChange:
			if midjRequest.TaskId == "" {
				return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_id_is_required")
			}
			if midjRequest.Action == "" {
				return provider.MidjourneyErrorWrapper(provider.MjRequestError, "action_is_required")
			}
			if midjRequest.Index == 0 {
				return provider.MidjourneyErrorWrapper(provider.MjRequestError, "index_is_required")
			}
			mjId = midjRequest.TaskId
		case provider.RelayModeMidjourneySimpleChange:
			if midjRequest.Content == "" {
				return provider.MidjourneyErrorWrapper(provider.MjRequestError, "content_is_required")
			}
			params := ConvertSimpleChangeParams(midjRequest.Content)
			if params == nil {
				return provider.MidjourneyErrorWrapper(provider.MjRequestError, "content_parse_failed")
			}
			mjId = params.TaskId
			midjRequest.Action = params.Action
		case provider.RelayModeMidjourneyModal:
			mjId = midjRequest.TaskId
			midjRequest.Action = provider.MjActionModal
		}
		originTask := model.GetByMJId(userId, mjId)
		if originTask == nil {
			return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_not_found")
		}
		// Modal requests may target an unfinished task; every other follow-up
		// requires the original task to have succeeded.
		if originTask.Status != "SUCCESS" && relayMode != provider.RelayModeMidjourneyModal {
			return provider.MidjourneyErrorWrapper(provider.MjRequestError, "task_status_not_success")
		}
		// Follow-up actions must go through the channel that produced the
		// original task to be handled correctly.
		mjProvider, errWithMJ = getMJProviderWithChannelId(c, originTask.ChannelId)
		if errWithMJ != nil {
			return errWithMJ
		}
		log.Printf("检测到此操作为放大、变换、重绘获取原channel信息: %d", originTask.ChannelId)
		midjRequest.Prompt = originTask.Prompt
	}

	// Inpaint and custom-zoom submissions are not charged here.
	if midjRequest.Action == provider.MjActionInPaint || midjRequest.Action == provider.MjActionCustomZoom {
		consumeQuota = false
	}
	requestURL := getMjRequestPath(c.Request.URL.String())
	quotaInstance, errWithOA := getQuota(c, midjRequest.Action)
	if errWithOA != nil {
		return &provider.MidjourneyResponse{
			Code:        4,
			Description: errWithOA.Message,
		}
	}
	midjResponseWithStatus, responseBody, err := mjProvider.Send(60, requestURL)
	if err != nil {
		quotaInstance.Undo(c)
		return &midjResponseWithStatus.Response
	}
	midjResponse := &midjResponseWithStatus.Response

	// Upstream result codes
	// (https://github.com/novicezk/midjourney-proxy/blob/main/docs/api.md):
	//   1  - submitted
	//   21 - task already exists (in progress or finished); properties carry
	//        "status" and "imageUrl"
	//   22 - queued
	//   23 - queue full, retry later
	//   24 - prompt contains a banned word
	//   other - submission error; description holds the reason
	// Only 1, 21 and 22 count as an accepted submission.
	accepted := midjResponse.Code == 1 || midjResponse.Code == 21 || midjResponse.Code == 22
	if !accepted {
		consumeQuota = false
	}
	// Settle the quota only once the result code is known, so that rejected
	// submissions (e.g. HTTP 200 with code 24) are not billed.
	if consumeQuota && midjResponseWithStatus.StatusCode == 200 {
		quotaInstance.Consume(c, &types.Usage{CompletionTokens: 0, PromptTokens: 1, TotalTokens: 1})
	} else {
		quotaInstance.Undo(c)
	}
	quota := int(quotaInstance.GetInputRatio() * 1000)

	midjourneyTask := &model.Midjourney{
		UserId:      userId,
		Code:        midjResponse.Code,
		Action:      midjRequest.Action,
		MjId:        midjResponse.Result,
		Prompt:      midjRequest.Prompt,
		PromptEn:    "",
		Description: midjResponse.Description,
		State:       "",
		SubmitTime:  time.Now().UnixNano() / int64(time.Millisecond),
		StartTime:   0,
		FinishTime:  0,
		ImageUrl:    "",
		Status:      "",
		Progress:    "0%",
		FailReason:  "",
		ChannelId:   c.GetInt("channel_id"),
		Quota:       quota,
	}
	if !accepted {
		// Record why the submission was rejected.
		midjourneyTask.FailReason = midjResponse.Description
	}
	if midjResponse.Code == 21 { // task already exists (in progress or finished)
		properties, ok := midjResponse.Properties.(map[string]interface{})
		if ok {
			imageUrl, ok1 := properties["imageUrl"].(string)
			status, ok2 := properties["status"].(string)
			if ok1 && ok2 {
				midjourneyTask.ImageUrl = imageUrl
				midjourneyTask.Status = status
				if status == "SUCCESS" {
					midjourneyTask.Progress = "100%"
					midjourneyTask.StartTime = time.Now().UnixNano() / int64(time.Millisecond)
					midjourneyTask.FinishTime = time.Now().UnixNano() / int64(time.Millisecond)
					midjResponse.Code = 1
				}
			}
		}
		// Report "already exists" to the client as a plain success.
		if midjRequest.Action != provider.MjActionInPaint && midjRequest.Action != provider.MjActionCustomZoom {
			newBody := strings.Replace(string(responseBody), `"code":21`, `"code":1`, -1)
			responseBody = []byte(newBody)
		}
	}
	err = midjourneyTask.Insert()
	if err != nil {
		return &provider.MidjourneyResponse{
			Code:        4,
			Description: "insert_midjourney_task_failed",
		}
	}
	// Trigger the bulk task-status refresh.
	controller.ActivateUpdateMidjourneyTaskBulk()
	if midjResponse.Code == 22 { // queued: report to the client as a plain success
		newBody := strings.Replace(string(responseBody), `"code":22`, `"code":1`, -1)
		responseBody = []byte(newBody)
	}

	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(midjResponseWithStatus.StatusCode)
	if _, err = c.Writer.Write(responseBody); err != nil {
		return &provider.MidjourneyResponse{
			Code:        4,
			Description: "copy_response_body_failed",
		}
	}
	return nil
}
// getMjRequestPath normalises an incoming request path for forwarding
// upstream.
//
// A path containing "/mj-" (e.g. a "/mj-fast/..." route prefix) is rewritten
// to start at its first "/mj/" segment, for example
// "/mj-fast/mj/submit/imagine" -> "/mj/submit/imagine". Everything after that
// segment, including any query string, is preserved.
//
// A path without "/mj-", or without any "/mj/" segment, is returned unchanged.
func getMjRequestPath(path string) string {
	if !strings.Contains(path, "/mj-") {
		return path
	}
	// Slice from the first "/mj/" rather than using strings.Split(...)[1],
	// which drops everything from a second "/mj/" onward (e.g. one inside a
	// query string).
	idx := strings.Index(path, "/mj/")
	if idx < 0 {
		return path
	}
	return path[idx:]
}
// getQuota prepares a quota reservation for a Midjourney action.
//
// The action is mapped to a billing model name with CoverActionToModelName and
// passed to util.NewQuota together with a fixed 1000. The meaning of that
// 1000 (presumably a token estimate for the reservation) is defined by
// util.NewQuota — TODO confirm.
func getQuota(c *gin.Context, action string) (*util.Quota, *types.OpenAIErrorWithStatusCode) {
	return util.NewQuota(c, CoverActionToModelName(action), 1000)
}
// getMJProviderWithRequest resolves the Midjourney provider for a request.
//
// It first derives a model name from the relay mode and the parsed request
// (request may be nil, as RelaySwapFace passes), then selects a provider for
// that model. It fails if the model cannot be derived or is empty.
func getMJProviderWithRequest(c *gin.Context, relayMode int, request *provider.MidjourneyRequest) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
	modelName, modelErr, _ := GetMjRequestModel(relayMode, request)
	switch {
	case modelErr != nil:
		return nil, MidjourneyErrorFromInternal(modelErr.Code, modelErr.Description)
	case modelName == "":
		return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法解析模型")
	default:
		return getMJProvider(c, modelName)
	}
}
// getMJProviderWithChannelId resolves the Midjourney provider for a specific
// channel.
//
// It stores the channel id under the "specific_channel_id" context key and
// then resolves a provider with an empty model name. Presumably
// relay.GetProvider honours that key when choosing the channel — TODO confirm.
func getMJProviderWithChannelId(c *gin.Context, channel_id int) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
	const noModel = ""
	c.Set("specific_channel_id", channel_id)
	return getMJProvider(c, noModel)
}
// getMJProvider looks up a provider for modelName via relay.GetProvider and
// requires it to be a *provider.MidjourneyProvider.
//
// It returns an unknown-error response if the lookup fails or the provider is
// of any other type.
func getMJProvider(c *gin.Context, modelName string) (*provider.MidjourneyProvider, *provider.MidjourneyResponse) {
	baseProvider, _, lookupErr := relay.GetProvider(c, modelName)
	if lookupErr != nil {
		return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无法获取provider:"+lookupErr.Error())
	}
	if mjProvider, isMidjourney := baseProvider.(*provider.MidjourneyProvider); isMidjourney {
		return mjProvider, nil
	}
	return nil, MidjourneyErrorFromInternal(provider.MjErrorUnknown, "无效的请求, 无法获取midjourney provider")
}